{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.tft"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TFT"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In summary Temporal Fusion Transformer (TFT) combines gating layers, an LSTM recurrent encoder, with multi-head attention layers for a multi-step forecasting strategy decoder.<br>TFT's inputs are static exogenous $\\mathbf{x}^{(s)}$, historic exogenous $\\mathbf{x}^{(h)}_{[:t]}$, exogenous available at the time of the prediction $\\mathbf{x}^{(f)}_{[:t+H]}$ and autorregresive features $\\mathbf{y}_{[:t]}$, each of these inputs is further decomposed into categorical and continuous. The network uses a multi-quantile regression to model the following conditional probability:$$\\mathbb{P}(\\mathbf{y}_{[t+1:t+H]}|\\;\\mathbf{y}_{[:t]},\\; \\mathbf{x}^{(h)}_{[:t]},\\; \\mathbf{x}^{(f)}_{[:t+H]},\\; \\mathbf{x}^{(s)})$$\n",
    "\n",
    "**References**<br>\n",
    "- [Jan Golda, Krzysztof Kudrynski. \"NVIDIA, Deep Learning Forecasting Examples\"](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Forecasting/TFT)<br>\n",
    "- [Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister, \"Temporal Fusion Transformers for interpretable multi-horizon time series forecasting\"](https://www.sciencedirect.com/science/article/pii/S0169207021000637)<br>"
   ]
  },
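  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The multi-quantile regression mentioned above is commonly trained with the quantile (pinball) loss. The following is a minimal, library-free sketch of that loss for a single observation and a single quantile level. The function name `pinball_loss` is hypothetical and is not part of `neuralforecast`; it only illustrates the idea.\n",
    "\n",
    "```python\n",
    "def pinball_loss(y_true: float, y_pred: float, q: float) -> float:\n",
    "    # Quantile (pinball) loss for one observation and one quantile level q in (0, 1).\n",
    "    # Under-prediction (y_true > y_pred) is weighted by q, over-prediction by (1 - q).\n",
    "    diff = y_true - y_pred\n",
    "    return max(q * diff, (q - 1.0) * diff)\n",
    "\n",
    "\n",
    "# With q = 0.9, under-predicting by 2 costs more than over-predicting by 2.\n",
    "under = pinball_loss(10.0, 8.0, q=0.9)\n",
    "over = pinball_loss(10.0, 12.0, q=0.9)\n",
    "assert under > over\n",
    "```\n",
    "\n",
    "At `q = 0.5` the loss equals half the absolute error, so the median quantile is optimized by an MAE-type objective."
   ]
  },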
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. Temporal Fusion Transformer Architecture.](imgs_models/tft_architecture.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "from torch.nn import LayerNorm\n",
    "\n",
    "import logging\n",
    "import warnings\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from typing import Tuple, Optional\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE\n",
    "from neuralforecast.common._base_windows import BaseWindows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import logging\n",
    "import warnings\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Auxiliary Functions"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 Gating Mechanisms\n",
    "\n",
    "The Gated Residual Network (GRN) provides adaptive depth and network complexity capable of accommodating different size datasets. As residual connections allow for the network to skip the non-linear transformation of input $\\mathbf{a}$ and context $\\mathbf{c}$.\n",
    "\n",
    "\\begin{align}\n",
    "\\eta_{1} &= \\mathrm{ELU}(\\mathbf{W}_{1}\\mathbf{a}+\\mathbf{W}_{2}\\mathbf{c}+\\mathbf{b}_{1}) \\\\\n",
    "\\eta_{2} &= \\mathbf{W}_{2}\\eta_{1}+b_{2} \\\\\n",
    "\\mathrm{GRN}(\\mathbf{a}, \\mathbf{c}) &= \\mathrm{LayerNorm}(a + \\textrm{GLU}(\\eta_{2}))\n",
    "\\end{align}\n",
    "\n",
    "The Gated Linear Unit (GLU) provides the flexibility of supressing unnecesary parts of the GRN. Consider GRN's output $\\gamma$ then GLU transformation is defined by:\n",
    "\n",
    "$$\\mathrm{GLU}(\\gamma) = \\sigma(\\mathbf{W}_{4}\\gamma +b_{4}) \\odot (\\mathbf{W}_{5}\\gamma +b_{5})$$"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 2. Gated Residual Network.](imgs_models/tft_grn.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class MaybeLayerNorm(nn.Module):\n",
    "    def __init__(self, output_size, hidden_size, eps):\n",
    "        super().__init__()\n",
    "        if output_size and output_size == 1:\n",
    "            self.ln = nn.Identity()\n",
    "        else:\n",
    "            self.ln = LayerNorm(output_size if output_size else hidden_size,\n",
    "                                eps=eps)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.ln(x)\n",
    "\n",
    "class GLU(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        self.lin = nn.Linear(hidden_size, output_size * 2)\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tensor:\n",
    "        x = self.lin(x)\n",
    "        x = F.glu(x)\n",
    "        return x\n",
    "\n",
    "class GRN(nn.Module):\n",
    "    def __init__(self,\n",
    "                 input_size,\n",
    "                 hidden_size, \n",
    "                 output_size=None,\n",
    "                 context_hidden_size=None,\n",
    "                 dropout=0):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)\n",
    "        self.lin_a = nn.Linear(input_size, hidden_size)\n",
    "        if context_hidden_size is not None:\n",
    "            self.lin_c = nn.Linear(context_hidden_size, hidden_size, bias=False)\n",
    "        self.lin_i = nn.Linear(hidden_size, hidden_size)\n",
    "        self.glu = GLU(hidden_size, output_size if output_size else hidden_size)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.out_proj = nn.Linear(input_size, output_size) if output_size else None\n",
    "\n",
    "    def forward(self, a: Tensor, c: Optional[Tensor] = None):\n",
    "        x = self.lin_a(a)\n",
    "        if c is not None:\n",
    "            x = x + self.lin_c(c).unsqueeze(1)\n",
    "        x = F.elu(x)\n",
    "        x = self.lin_i(x)\n",
    "        x = self.dropout(x)\n",
    "        x = self.glu(x)\n",
    "        y = a if not self.out_proj else self.out_proj(a)\n",
    "        x = x + y\n",
    "        x = self.layer_norm(x)\n",
    "        return x"
   ]
  },
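  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the `GLU` module above, the sketch below illustrates how a single linear layer of width `2 * output_size` followed by `F.glu` reproduces the two-branch formula: `F.glu` splits its input in half along the last dimension and multiplies the first half by the sigmoid of the second half. The sizes and variable names are arbitrary assumptions chosen for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "hidden_size, output_size = 8, 4          # hypothetical sizes\n",
    "lin = torch.nn.Linear(hidden_size, 2 * output_size)\n",
    "gamma = torch.randn(3, hidden_size)      # a batch of 3 GLU inputs\n",
    "\n",
    "# Fused version, as used by the GLU module.\n",
    "z = lin(gamma)\n",
    "fused = F.glu(z, dim=-1)                 # [3, output_size]\n",
    "\n",
    "# Explicit version: the first half is the value branch, the second half the gate.\n",
    "value, gate = z.chunk(2, dim=-1)\n",
    "manual = value * torch.sigmoid(gate)\n",
    "\n",
    "assert torch.allclose(fused, manual)\n",
    "```\n",
    "\n",
    "When the gate saturates near zero the GLU output vanishes, and the GRN reduces to a layer-normalized skip connection."
   ]
  },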
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Variable Selection Networks\n",
    "\n",
    "TFT includes automated variable selection capabilities, through its variable selection network (VSN) components. The VSN takes the original input $\\{\\mathbf{x}^{(s)}, \\mathbf{x}^{(h)}_{[:t]}, \\mathbf{x}^{(f)}_{[:t]}\\}$ and transforms it through embeddings or linear transformations into a high dimensional space\n",
    "$\\{\\mathbf{E}^{(s)}, \\mathbf{E}^{(h)}_{[:t]}, \\mathbf{E}^{(f)}_{[:t+H]}\\}$. \n",
    "\n",
    "For the observed historic data, the embedding matrix $\\mathbf{E}^{(h)}_{t}$ at time $t$ is a concatenation of $j$ variable $e^{(h)}_{t,j}$ embeddings:\n",
    "\\begin{align}\n",
    "\\mathbf{E}^{(h)}_{t} &= [e^{(h)}_{t,1},\\dots,e^{(h)}_{t,j},\\dots,e^{(h)}_{t,n_{h}}] \\\\\n",
    "\\mathbf{\\tilde{e}}^{(h)}_{t,j} &= \\mathrm{GRN}(e^{(h)}_{t,j})\n",
    "\\end{align}\n",
    "\n",
    "The variable selection weights are given by:\n",
    "$$s^{(h)}_{t}=\\mathrm{SoftMax}(\\mathrm{GRN}(\\mathbf{E}^{(h)}_{t},\\mathbf{E}^{(s)}))$$\n",
    "\n",
    "The VSN processed features are then:\n",
    "$$\\tilde{\\mathbf{E}}^{(h)}_{t}= \\sum_{j} s^{(h)}_{j} \\tilde{e}^{(h)}_{t,j}$$"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 3. Variable Selection Network.](imgs_models/tft_vsn.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class TFTEmbedding(nn.Module):\n",
    "    def __init__(self, hidden_size, stat_input_size, futr_input_size, hist_input_size, tgt_size):\n",
    "        super().__init__()\n",
    "        # There are 4 types of input:\n",
    "        # 1. Static continuous\n",
    "        # 2. Temporal known a priori continuous\n",
    "        # 3. Temporal observed continuous\n",
    "        # 4. Temporal observed targets (time series obseved so far)\n",
    "\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.stat_input_size = stat_input_size\n",
    "        self.futr_input_size = futr_input_size\n",
    "        self.hist_input_size = hist_input_size\n",
    "        self.tgt_size        = tgt_size\n",
    "\n",
    "        # Instantiate Continuous Embeddings if size is not None\n",
    "        for attr, size in [('stat_exog_embedding', stat_input_size), \n",
    "                           ('futr_exog_embedding', futr_input_size),\n",
    "                           ('hist_exog_embedding', hist_input_size),\n",
    "                           ('tgt_embedding', tgt_size)]:\n",
    "            if size:\n",
    "                vectors = nn.Parameter(torch.Tensor(size, hidden_size))\n",
    "                bias = nn.Parameter(torch.zeros(size, hidden_size))\n",
    "                torch.nn.init.xavier_normal_(vectors)\n",
    "                setattr(self, attr+'_vectors', vectors)\n",
    "                setattr(self, attr+'_bias', bias)\n",
    "            else:\n",
    "                setattr(self, attr+'_vectors', None)\n",
    "                setattr(self, attr+'_bias', None)\n",
    "\n",
    "    def _apply_embedding(self,\n",
    "                         cont: Optional[Tensor],\n",
    "                         cont_emb: Tensor,\n",
    "                         cont_bias: Tensor,\n",
    "                         ):\n",
    "\n",
    "        if (cont is not None):\n",
    "            #the line below is equivalent to following einsums\n",
    "            #e_cont = torch.einsum('btf,fh->bthf', cont, cont_emb)\n",
    "            #e_cont = torch.einsum('bf,fh->bhf', cont, cont_emb)          \n",
    "            e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)\n",
    "            e_cont = e_cont + cont_bias\n",
    "            return e_cont\n",
    "        \n",
    "        return None\n",
    "\n",
    "    def forward(self, target_inp, \n",
    "                stat_exog=None, futr_exog=None, hist_exog=None):\n",
    "        # temporal/static categorical/continuous known/observed input \n",
    "        # tries to get input, if fails returns None\n",
    "\n",
    "        # Static inputs are expected to be equal for all timesteps\n",
    "        # For memory efficiency there is no assert statement\n",
    "        stat_exog = stat_exog[:,:] if stat_exog is not None else None\n",
    "\n",
    "        s_inp = self._apply_embedding(cont=stat_exog,\n",
    "                                      cont_emb=self.stat_exog_embedding_vectors,\n",
    "                                      cont_bias=self.stat_exog_embedding_bias)\n",
    "        k_inp = self._apply_embedding(cont=futr_exog,\n",
    "                                      cont_emb=self.futr_exog_embedding_vectors,\n",
    "                                      cont_bias=self.futr_exog_embedding_bias)\n",
    "        o_inp = self._apply_embedding(cont=hist_exog,\n",
    "                                      cont_emb=self.hist_exog_embedding_vectors,\n",
    "                                      cont_bias=self.hist_exog_embedding_bias)\n",
    "\n",
    "        # Temporal observed targets\n",
    "        # t_observed_tgt = torch.einsum('btf,fh->btfh', \n",
    "        #                               target_inp, self.tgt_embedding_vectors)        \n",
    "        target_inp = torch.matmul(target_inp.unsqueeze(3).unsqueeze(4),\n",
    "                          self.tgt_embedding_vectors.unsqueeze(1)).squeeze(3)\n",
    "        target_inp = target_inp + self.tgt_embedding_bias\n",
    "\n",
    "        return s_inp, k_inp, o_inp, target_inp\n",
    "\n",
    "class VariableSelectionNetwork(nn.Module):\n",
    "    def __init__(self, hidden_size, num_inputs, dropout):\n",
    "        super().__init__()\n",
    "        self.joint_grn = GRN(input_size=hidden_size*num_inputs, \n",
    "                             hidden_size=hidden_size, \n",
    "                             output_size=num_inputs, \n",
    "                             context_hidden_size=hidden_size)\n",
    "        self.var_grns = nn.ModuleList(\n",
    "                        [GRN(input_size=hidden_size, \n",
    "                             hidden_size=hidden_size, dropout=dropout)\n",
    "                         for _ in range(num_inputs)])\n",
    "\n",
    "    def forward(self, x: Tensor, context: Optional[Tensor] = None):\n",
    "        Xi = x.reshape(*x.shape[:-2], -1)\n",
    "        grn_outputs = self.joint_grn(Xi, c=context)\n",
    "        sparse_weights = F.softmax(grn_outputs, dim=-1)\n",
    "        transformed_embed_list = [m(x[...,i,:])\n",
    "                                     for i, m in enumerate(self.var_grns)]\n",
    "        transformed_embed = torch.stack(transformed_embed_list, dim=-1)\n",
    "        #the line below performs batched matrix vector multiplication\n",
    "        #for temporal features it's bthf,btf->bth\n",
    "        #for static features it's bhf,bf->bh\n",
    "        variable_ctx = torch.matmul(transformed_embed, \n",
    "                                    sparse_weights.unsqueeze(-1)).squeeze(-1)\n",
    "\n",
    "        return variable_ctx, sparse_weights"
   ]
  },
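  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two broadcasting tricks used above can be checked against explicit einsums. The sketch below is only a shape walk-through under assumed sizes: random tensors stand in for the learned embedding vectors and for the per-variable GRN outputs, and none of the variable names belong to the library.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "B, T, n_vars, hidden = 2, 5, 3, 8        # hypothetical sizes\n",
    "\n",
    "# Continuous embedding: each scalar feature scales its own learned vector.\n",
    "cont = torch.randn(B, T, n_vars)\n",
    "vectors = torch.randn(n_vars, hidden)\n",
    "embedded = cont.unsqueeze(-1) * vectors  # [B, T, n_vars, hidden]\n",
    "assert torch.allclose(embedded, torch.einsum('btf,fh->btfh', cont, vectors))\n",
    "\n",
    "# Variable selection: softmax weights combine the per-variable representations.\n",
    "transformed = torch.randn(B, T, hidden, n_vars)             # stacked on the last dim\n",
    "weights = torch.softmax(torch.randn(B, T, n_vars), dim=-1)  # sums to 1 over variables\n",
    "selected = torch.matmul(transformed, weights.unsqueeze(-1)).squeeze(-1)  # [B, T, hidden]\n",
    "reference = torch.einsum('bthf,btf->bth', transformed, weights)\n",
    "assert torch.allclose(selected, reference, atol=1e-5)\n",
    "```\n",
    "\n",
    "Because the weights sum to one over the variable axis, each time step's output is a convex combination of the per-variable representations, which is what makes the weights readable as variable importances."
   ]
  },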
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3. Multi-Head Attention\n",
    "\n",
    "To avoid information bottlenecks from the classic Seq2Seq architecture, TFT \n",
    "incorporates a decoder-encoder attention mechanism inherited transformer architectures ([Li et. al 2019](https://arxiv.org/abs/1907.00235), [Vaswani et. al 2017](https://arxiv.org/abs/1706.03762)). It transform the the outputs of the LSTM encoded temporal features, and helps the decoder better capture long-term relationships.\n",
    "\n",
    "The original multihead attention for each component $H_{m}$ and its query, key, and value representations are denoted by $Q_{m}, K_{m}, V_{m}$, its transformation is given by:\n",
    "\n",
    "\\begin{align}\n",
    "Q_{m} = Q W_{Q,m} \\quad K_{m} = K W_{K,h} \\quad V_{m} = V W_{V,m} \\\\\n",
    "H_{m}=\\mathrm{Attention}(Q_{m}, K_{m}, V_{m}) = \\mathrm{SoftMax}(Q_{m} K^{\\intercal}_{m}/\\mathrm{scale}) \\; V_{m} \\\\\n",
    "\\mathrm{MultiHead}(Q, K, V) = [H_{1},\\dots,H_{M}] W_{M}\n",
    "\\end{align}\n",
    "\n",
    "TFT modifies the original multihead attention to improve its interpretability. To do it it uses shared values $\\tilde{V}$ across heads and employs additive aggregation, $\\mathrm{InterpretableMultiHead}(Q,K,V) = \\tilde{H} W_{M}$. The mechanism has a great resemblence to a single attention layer, but it allows for $M$ multiple attention weights, and can be therefore be interpreted as the average ensemble of $M$ single attention layers.\n",
    "\n",
    "\\begin{align}\n",
    "\\tilde{H} &= \\left(\\frac{1}{M} \\sum_{m} \\mathrm{SoftMax}(Q_{m} K^{\\intercal}_{m}/\\mathrm{scale}) \\right) \\tilde{V} \n",
    "          = \\frac{1}{M} \\sum_{m} \\mathrm{Attention}(Q_{m}, K_{m}, \\tilde{V}) \\\\\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class InterpretableMultiHeadAttention(nn.Module):\n",
    "    def __init__(self, n_head, hidden_size, example_length,\n",
    "                 attn_dropout, dropout):\n",
    "        super().__init__()\n",
    "        self.n_head = n_head\n",
    "        assert hidden_size % n_head == 0\n",
    "        self.d_head = hidden_size // n_head\n",
    "        self.qkv_linears = nn.Linear(hidden_size, \n",
    "                                     (2 * self.n_head + 1) * self.d_head,\n",
    "                                     bias=False)\n",
    "        self.out_proj = nn.Linear(self.d_head, hidden_size, bias=False)\n",
    "\n",
    "        self.attn_dropout = nn.Dropout(attn_dropout)\n",
    "        self.out_dropout = nn.Dropout(dropout)\n",
    "        self.scale = self.d_head**-0.5\n",
    "        self.register_buffer(\"_mask\",\n",
    "          torch.triu(torch.full((example_length, example_length), \n",
    "                                float('-inf')), 1).unsqueeze(0))\n",
    "\n",
    "    def forward(self, x: Tensor, \n",
    "                mask_future_timesteps: bool = True) -> Tuple[Tensor, Tensor]:\n",
    "        # [Batch,Time,MultiHead,AttDim] := [N,T,M,AD]\n",
    "        bs, t, h_size = x.shape\n",
    "        qkv = self.qkv_linears(x)\n",
    "        q, k, v = qkv.split((self.n_head * self.d_head, \n",
    "                             self.n_head * self.d_head, self.d_head), dim=-1)\n",
    "        q = q.view(bs, t, self.n_head, self.d_head)\n",
    "        k = k.view(bs, t, self.n_head, self.d_head)\n",
    "        v = v.view(bs, t, self.d_head)\n",
    "        \n",
    "        # [N,T1,M,Ad] x [N,T2,M,Ad] -> [N,M,T1,T2]\n",
    "        # attn_score = torch.einsum('bind,bjnd->bnij', q, k)\n",
    "        attn_score = torch.matmul(q.permute((0, 2, 1, 3)), \n",
    "                                  k.permute((0, 2, 3, 1)))\n",
    "        attn_score.mul_(self.scale)\n",
    "\n",
    "        if mask_future_timesteps:\n",
    "            attn_score = attn_score + self._mask\n",
    "\n",
    "        attn_prob = F.softmax(attn_score, dim=3)\n",
    "        attn_prob = self.attn_dropout(attn_prob)\n",
    "\n",
    "        # [N,M,T1,T2] x [N,M,T1,Ad] -> [N,M,T1,Ad]\n",
    "        # attn_vec = torch.einsum('bnij,bjd->bnid', attn_prob, v)\n",
    "        attn_vec = torch.matmul(attn_prob, v.unsqueeze(1))\n",
    "        m_attn_vec = torch.mean(attn_vec, dim=1)\n",
    "        out = self.out_proj(m_attn_vec)\n",
    "        out = self.out_dropout(out)\n",
    "\n",
    "        return out, attn_vec"
   ]
  },
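  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The identity behind the interpretable attention, that averaging the per-head outputs equals applying the head-averaged attention weights to the shared values, can be verified numerically. The sketch below uses random queries, keys, and values with assumed sizes; it leaves out the learned projections and dropout of `InterpretableMultiHeadAttention`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "B, M, T, d = 2, 4, 6, 8                  # batch, heads, time steps, head size (hypothetical)\n",
    "\n",
    "q = torch.randn(B, M, T, d)              # per-head queries\n",
    "k = torch.randn(B, M, T, d)              # per-head keys\n",
    "v = torch.randn(B, T, d)                 # values shared by all heads\n",
    "\n",
    "# Scaled dot-product scores with a causal mask (-inf above the diagonal).\n",
    "scores = (q @ k.transpose(-2, -1)) * d ** -0.5\n",
    "mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)\n",
    "probs = torch.softmax(scores + mask, dim=-1)            # [B, M, T, T]\n",
    "\n",
    "# The average of the per-head outputs ...\n",
    "mean_of_heads = (probs @ v.unsqueeze(1)).mean(dim=1)    # [B, T, d]\n",
    "# ... equals the head-averaged attention weights applied to the shared values.\n",
    "mean_weights = probs.mean(dim=1)                        # [B, T, T]\n",
    "heads_of_mean = mean_weights @ v                        # [B, T, d]\n",
    "\n",
    "assert torch.allclose(mean_of_heads, heads_of_mean, atol=1e-5)\n",
    "```\n",
    "\n",
    "Each row of `mean_weights` is a single attention pattern over past time steps, which is the quantity one would inspect for interpretation."
   ]
  },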
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. TFT Architecture\n",
    "\n",
    "The first TFT's step is embed the original input $\\{\\mathbf{x}^{(s)}, \\mathbf{x}^{(h)}, \\mathbf{x}^{(f)}\\}$ into a high dimensional space $\\{\\mathbf{E}^{(s)}, \\mathbf{E}^{(h)}, \\mathbf{E}^{(f)}\\}$, after which each embedding is gated by a variable selection network (VSN). The static embedding $\\mathbf{E}^{(s)}$ is used as context for variable selection and as initial condition to the LSTM. Finally the encoded variables are fed into the multi-head attention decoder.\n",
    "\n",
    "\\begin{align}\n",
    " c_{s}, c_{e}, (c_{h}, c_{c}) &=\\textrm{StaticCovariateEncoder}(\\mathbf{E}^{(s)}) \\\\ \n",
    "      h_{[:t]}, h_{[t+1:t+H]}  &=\\textrm{TemporalCovariateEncoder}(\\mathbf{E}^{(h)}, \\mathbf{E}^{(f)}, c_{h}, c_{c}) \\\\\n",
    "\\hat{\\mathbf{y}}^{(q)}_{[t+1:t+H]} &=\\textrm{TemporalFusionDecoder}(h_{[t+1:t+H]}, c_{e})\n",
    "\\end{align}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Static Covariate Encoder\n",
    "\n",
    "The static embedding $\\mathbf{E}^{(s)}$ is transformed by the StaticCovariateEncoder into contexts $c_{s}, c_{e}, c_{h}, c_{c}$. Where $c_{s}$ are temporal variable selection contexts, $c_{e}$ are TemporalFusionDecoder enriching contexts, and $c_{h}, c_{c}$ are LSTM's hidden/contexts for the TemporalCovariateEncoder.\n",
    "\n",
    "\\begin{align}\n",
    "c_{s}, c_{e}, (c_{h}, c_{c}) & = \\textrm{GRN}(\\textrm{VSN}(\\mathbf{E}^{(s)}))\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class StaticCovariateEncoder(nn.Module):\n",
    "    def __init__(self, hidden_size, num_static_vars, dropout):\n",
    "        super().__init__()\n",
    "        self.vsn = VariableSelectionNetwork(hidden_size=hidden_size,\n",
    "                                            num_inputs=num_static_vars,\n",
    "                                            dropout=dropout)\n",
    "        self.context_grns = nn.ModuleList(\n",
    "                              [GRN(input_size=hidden_size,\n",
    "                                   hidden_size=hidden_size,\n",
    "                                   dropout=dropout) for _ in range(4)])\n",
    "\n",
    "    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n",
    "        variable_ctx, sparse_weights = self.vsn(x)\n",
    "\n",
    "        # Context vectors:\n",
    "        # variable selection context\n",
    "        # enrichment context\n",
    "        # state_c context\n",
    "        # state_h context\n",
    "        cs, ce, ch, cc = tuple(m(variable_ctx) for m in self.context_grns)\n",
    "\n",
    "        return cs, ce, ch, cc"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Temporal Covariate Encoder\n",
    "\n",
    "TemporalCovariateEncoder encodes the embeddings $\\mathbf{E}^{(h)}, \\mathbf{E}^{(f)}$ and contexts  $(c_{h}, c_{c})$ with an LSTM.\n",
    "\n",
    "\\begin{align}\n",
    "\\tilde{\\mathbf{E}}^{(h)}_{[:t]} & = \\textrm{VSN}(\\mathbf{E}^{(h)}_{[:t]}, c_{s}) \\\\\n",
    "\\tilde{\\mathbf{E}}^{(h)}_{[:t]} &= \\mathrm{LSTM}(\\tilde{\\mathbf{E}}^{(h)}_{[:t]}, (c_{h}, c_{c})) \\\\\n",
    "h_{[:t]} &= \\mathrm{Gate}(\\mathrm{LayerNorm}(\\tilde{\\mathbf{E}}^{(h)}_{[:t]}))\n",
    "\\end{align}\n",
    "\n",
    "An analogous process is repeated for the future data, with the main difference that $\\mathbf{E}^{(f)}$ contains the future available information.\n",
    "\n",
    "\\begin{align}\n",
    "\\tilde{\\mathbf{E}}^{(f)}_{[t+1:t+h]} & = \\textrm{VSN}(\\mathbf{E}^{(h)}_{t+1:t+H}, \\mathbf{E}^{(f)}_{t+1:t+H}, c_{s}) \\\\\n",
    "\\tilde{\\mathbf{E}}^{(f)}_{[t+1:t+h]} &= \\mathrm{LSTM}(\\tilde{\\mathbf{E}}^{(h)}_{[t+1:t+h]}, (c_{h}, c_{c})) \\\\\n",
    "h_{[t+1:t+H]} &= \\mathrm{Gate}(\\mathrm{LayerNorm}(\\tilde{\\mathbf{E}}^{(f)}_{[t+1:t+h]}))\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class TemporalCovariateEncoder(nn.Module):\n",
    "    def __init__(self, hidden_size, \n",
    "                 num_historic_vars, num_future_vars, dropout):\n",
    "        super(TemporalCovariateEncoder, self).__init__()\n",
    "\n",
    "        self.history_vsn = VariableSelectionNetwork(\n",
    "                                       hidden_size=hidden_size,\n",
    "                                       num_inputs=num_historic_vars,\n",
    "                                       dropout=dropout)\n",
    "        self.history_encoder = nn.LSTM(input_size=hidden_size,\n",
    "                                       hidden_size=hidden_size,\n",
    "                                       batch_first=True)\n",
    "        \n",
    "        self.future_vsn = VariableSelectionNetwork(hidden_size=hidden_size,\n",
    "                                                   num_inputs=num_future_vars,\n",
    "                                                   dropout=dropout)\n",
    "        self.future_encoder = nn.LSTM(input_size=hidden_size,\n",
    "                                      hidden_size=hidden_size,\n",
    "                                      batch_first=True)\n",
    "        \n",
    "        # Shared Gated-Skip Connection\n",
    "        self.input_gate = GLU(hidden_size, hidden_size)\n",
    "        self.input_gate_ln = LayerNorm(hidden_size, eps=1e-3)\n",
    "    \n",
    "    def forward(self, historical_inputs, future_inputs, cs, ch, cc):\n",
    "        # [N,X_in,L] -> [N,hidden_size,L]\n",
    "        historical_features, _ = self.history_vsn(historical_inputs, cs)\n",
    "        history, state = self.history_encoder(historical_features, (ch, cc))\n",
    "\n",
    "        future_features, _ = self.future_vsn(future_inputs, cs)\n",
    "        future, _ = self.future_encoder(future_features, state)\n",
    "        #torch.cuda.synchronize() # this call gives prf boost for unknown reasons\n",
    "\n",
    "        input_embedding = torch.cat([historical_features, future_features], dim=1)\n",
    "        temporal_features = torch.cat([history, future], dim=1)\n",
    "        temporal_features = self.input_gate(temporal_features)\n",
    "        temporal_features = temporal_features + input_embedding\n",
    "        temporal_features = self.input_gate_ln(temporal_features)      \n",
    "        return temporal_features"
   ]
  },
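  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below traces tensor shapes through the two LSTMs of `TemporalCovariateEncoder`. It assumes that the static contexts arrive with shape `[N, hidden_size]` and receive a leading layer dimension before being passed as the initial state; where that reshaping happens in the full model is not shown in this section, so treat it as an assumption. All sizes are arbitrary, and random tensors stand in for the VSN outputs.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "N, L, H, hidden = 2, 7, 3, 8   # batch, history length, horizon, hidden size (hypothetical)\n",
    "\n",
    "history_encoder = nn.LSTM(input_size=hidden, hidden_size=hidden, batch_first=True)\n",
    "future_encoder = nn.LSTM(input_size=hidden, hidden_size=hidden, batch_first=True)\n",
    "\n",
    "# Static contexts of shape [N, hidden]; nn.LSTM expects [num_layers, N, hidden].\n",
    "ch = torch.randn(N, hidden).unsqueeze(0)\n",
    "cc = torch.randn(N, hidden).unsqueeze(0)\n",
    "\n",
    "historical_features = torch.randn(N, L, hidden)\n",
    "future_features = torch.randn(N, H, hidden)\n",
    "\n",
    "# The history LSTM starts from the static contexts ...\n",
    "history, state = history_encoder(historical_features, (ch, cc))\n",
    "# ... and the future LSTM continues from the history LSTM's final state.\n",
    "future, _ = future_encoder(future_features, state)\n",
    "\n",
    "temporal_features = torch.cat([history, future], dim=1)   # [N, L + H, hidden]\n",
    "```\n",
    "\n",
    "Passing `state` from the first LSTM to the second lets the future steps see a summary of the whole history while the two encoders keep separate weights."
   ]
  },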
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Temporal Fusion Decoder\n",
    "\n",
    "The TemporalFusionDecoder enriches the LSTM's outputs with $c_{e}$ and then uses an attention layer, and multi-step adapter.\n",
    "\\begin{align}\n",
    "h_{[t+1:t+H]} &= \\mathrm{MultiHeadAttention}(h_{[:t]}, h_{[t+1:t+H]}, c_{e}) \\\\\n",
    "h_{[t+1:t+H]} &= \\mathrm{Gate}(\\mathrm{LayerNorm}(h_{[t+1:t+H]}) \\\\\n",
    "h_{[t+1:t+H]} &= \\mathrm{Gate}(\\mathrm{LayerNorm}(\\mathrm{GRN}(h_{[t+1:t+H]})) \\\\\n",
    "\\hat{\\mathbf{y}}^{(q)}_{[t+1:t+H]} &= \\mathrm{MLP}(h_{[t+1:t+H]})\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "class TemporalFusionDecoder(nn.Module):\n",
    "    def __init__(self, \n",
    "                 n_head, hidden_size, \n",
    "                 example_length, encoder_length,\n",
    "                 attn_dropout, dropout):\n",
    "        super(TemporalFusionDecoder, self).__init__()\n",
    "        self.encoder_length = encoder_length\n",
    "        \n",
    "        #------------- Encoder-Decoder Attention --------------#\n",
    "        self.enrichment_grn = GRN(input_size=hidden_size,\n",
    "                                  hidden_size=hidden_size,\n",
    "                                  context_hidden_size=hidden_size, \n",
    "                                  dropout=dropout)\n",
    "        self.attention = InterpretableMultiHeadAttention(\n",
    "                                       n_head=n_head,\n",
    "                                       hidden_size=hidden_size,\n",
    "                                       example_length=example_length,\n",
    "                                       attn_dropout=attn_dropout,\n",
    "                                       dropout=dropout)\n",
    "        self.attention_gate = GLU(hidden_size, hidden_size)\n",
    "        self.attention_ln = LayerNorm(normalized_shape=hidden_size, eps=1e-3)\n",
    "\n",
    "        self.positionwise_grn = GRN(input_size=hidden_size,\n",
    "                                    hidden_size=hidden_size,\n",
    "                                    dropout=dropout)\n",
    "        \n",
    "        #---------------------- Decoder -----------------------#\n",
    "        self.decoder_gate = GLU(hidden_size, hidden_size)\n",
    "        self.decoder_ln = LayerNorm(normalized_shape=hidden_size, eps=1e-3)\n",
    "        \n",
    "    \n",
    "    def forward(self, temporal_features, ce):\n",
    "        #------------- Encoder-Decoder Attention --------------#\n",
    "        # Static enrichment\n",
    "        enriched = self.enrichment_grn(temporal_features, c=ce)\n",
    "\n",
    "        # Temporal self attention\n",
    "        x, _ = self.attention(enriched, mask_future_timesteps=True)\n",
    "\n",
    "        # Don't compute historical quantiles\n",
    "        x = x[:, self.encoder_length:, :]\n",
    "        temporal_features = temporal_features[:, self.encoder_length:, :]\n",
    "        enriched = enriched[:, self.encoder_length:, :]\n",
    "\n",
    "        x = self.attention_gate(x)\n",
    "        x = x + enriched\n",
    "        x = self.attention_ln(x)\n",
    "\n",
    "        # Position-wise feed-forward\n",
    "        x = self.positionwise_grn(x)\n",
    "\n",
    "        #---------------------- Decoder ----------------------#\n",
    "        # Final skip connection\n",
    "        x = self.decoder_gate(x)\n",
    "        x = x + temporal_features\n",
    "        x = self.decoder_ln(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. TFT methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TFT(BaseWindows):\n",
    "    \"\"\" TFT\n",
    "\n",
    "    The Temporal Fusion Transformer architecture (TFT) is an Sequence-to-Sequence \n",
    "    model that combines static, historic and future available data to predict an\n",
    "    univariate target. The method combines gating layers, an LSTM recurrent encoder, \n",
    "    with and interpretable multi-head attention layer and a multi-step forecasting \n",
    "    strategy decoder.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `h`: int, Forecast horizon. <br>\n",
    "    `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>\n",
    "    `stat_exog_list`: str list, static continuous columns.<br>\n",
    "    `hist_exog_list`: str list, historic continuous columns.<br>\n",
    "    `futr_exog_list`: str list, future continuous columns.<br>\n",
    "    `hidden_size`: int, units of embeddings and encoders.<br>\n",
    "    `dropout`: float (0, 1), dropout of inputs VSNs.<br>\n",
    "    `n_head`: int=4, number of attention heads in temporal fusion decoder.<br>\n",
    "    `attn_dropout`: float (0, 1), dropout of fusion decoder's attention layer.<br>\n",
    "    `shared_weights`: bool, If True, all blocks within each stack will share parameters. <br>\n",
    "    `activation`: str, activation from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'PReLU', 'Sigmoid'].<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int, number of different series in each batch.<br>\n",
    "    `windows_batch_size`: int=None, windows sampled from rolled data, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=-1, number of windows to sample in each inference batch, -1 uses all.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch.<br>\n",
    "    `step_size`: int=1, step size between each window of temporal data.<br>\n",
    "    `scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int, random seed initialization for replicability.<br>\n",
    "    `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>    \n",
    "\n",
    "    **References:**<br>\n",
    "    - [Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister, \n",
    "    \"Temporal Fusion Transformers for interpretable multi-horizon time series forecasting\"](https://www.sciencedirect.com/science/article/pii/S0169207021000637)\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'windows'\n",
    "    \n",
    "    def __init__(self,\n",
    "                 h,\n",
    "                 input_size,\n",
    "                 tgt_size: int = 1,\n",
    "                 stat_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 futr_exog_list = None,\n",
    "                 hidden_size: int = 128,\n",
    "                 n_head: int = 4,\n",
    "                 attn_dropout: float = 0.0,\n",
    "                 dropout: float = 0.1,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 1000,\n",
    "                 learning_rate: float = 1e-3,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size: int = 1024,\n",
    "                 inference_windows_batch_size: int = 1024,\n",
    "                 start_padding_enabled = False,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'robust',\n",
    "                 num_workers_loader = 0,\n",
    "                 drop_last_loader = False,\n",
    "                 random_seed: int = 1,\n",
    "                 **trainer_kwargs\n",
    "                 ):\n",
    "\n",
    "        # Inherit BaseWindows class\n",
    "        super(TFT, self).__init__(h=h,\n",
    "                                  input_size=input_size,\n",
    "                                  loss=loss,\n",
    "                                  valid_loss=valid_loss,\n",
    "                                  max_steps=max_steps,\n",
    "                                  learning_rate=learning_rate,\n",
    "                                  num_lr_decays=num_lr_decays,\n",
    "                                  early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                  val_check_steps=val_check_steps,\n",
    "                                  batch_size=batch_size,\n",
    "                                  valid_batch_size=valid_batch_size,\n",
    "                                  windows_batch_size=windows_batch_size,\n",
    "                                  inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                  start_padding_enabled=start_padding_enabled,\n",
    "                                  step_size=step_size,\n",
    "                                  scaler_type=scaler_type,\n",
    "                                  num_workers_loader=num_workers_loader,\n",
    "                                  drop_last_loader=drop_last_loader,\n",
    "                                  random_seed=random_seed,\n",
    "                                  **trainer_kwargs)\n",
    "        self.example_length = input_size + h\n",
    "\n",
    "        # Parse lists hyperparameters\n",
    "        self.stat_exog_list = [] if stat_exog_list is None else stat_exog_list\n",
    "        self.hist_exog_list = [] if hist_exog_list is None else hist_exog_list\n",
    "        self.futr_exog_list = [] if futr_exog_list is None else futr_exog_list\n",
    "\n",
    "        stat_input_size = len(self.stat_exog_list)\n",
    "        futr_input_size = max(len(self.futr_exog_list), 1)\n",
    "        hist_input_size = len(self.hist_exog_list)\n",
    "        num_historic_vars = futr_input_size + hist_input_size + tgt_size\n",
    "\n",
    "        #------------------------------- Encoders -----------------------------#\n",
    "        self.embedding = TFTEmbedding(hidden_size=hidden_size,\n",
    "                                      stat_input_size=stat_input_size,\n",
    "                                      futr_input_size=futr_input_size,\n",
    "                                      hist_input_size=hist_input_size,\n",
    "                                      tgt_size=tgt_size)\n",
    "        \n",
    "        self.static_encoder = StaticCovariateEncoder(\n",
    "                                      hidden_size=hidden_size,\n",
    "                                      num_static_vars=stat_input_size,\n",
    "                                      dropout=dropout)\n",
    "\n",
    "        self.temporal_encoder = TemporalCovariateEncoder(\n",
    "                                      hidden_size=hidden_size,\n",
    "                                      num_historic_vars=num_historic_vars,\n",
    "                                      num_future_vars=futr_input_size,\n",
    "                                      dropout=dropout)\n",
    "\n",
    "        #------------------------------ Decoders -----------------------------#\n",
    "        self.temporal_fusion_decoder = TemporalFusionDecoder(\n",
    "                                      n_head=n_head,\n",
    "                                      hidden_size=hidden_size,\n",
    "                                      example_length=self.example_length,\n",
    "                                      encoder_length=self.input_size,\n",
    "                                      attn_dropout=attn_dropout,\n",
    "                                      dropout=dropout)\n",
    "\n",
    "        # Adapter with Loss dependent dimensions\n",
    "        self.output_adapter = nn.Linear(in_features=hidden_size,\n",
    "                                        out_features=self.loss.outputsize_multiplier)\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "\n",
    "        # Parsiw windows_batch\n",
    "        y_insample = windows_batch['insample_y'][:,:, None] # <- [B,T,1]\n",
    "        futr_exog  = windows_batch['futr_exog']\n",
    "        hist_exog  = windows_batch['hist_exog']\n",
    "        stat_exog  = windows_batch['stat_exog']\n",
    "\n",
    "        if futr_exog is None:\n",
    "            futr_exog = y_insample[:, [-1]]\n",
    "            futr_exog = futr_exog.repeat(1, self.example_length, 1)\n",
    "\n",
    "        s_inp, k_inp, o_inp, t_observed_tgt = self.embedding(target_inp=y_insample, \n",
    "                                                             hist_exog=hist_exog,\n",
    "                                                             futr_exog=futr_exog,\n",
    "                                                             stat_exog=stat_exog)\n",
    "\n",
    "        #-------------------------------- Inputs ------------------------------#\n",
    "        # Static context\n",
    "        if s_inp is not None:\n",
    "            cs, ce, ch, cc = self.static_encoder(s_inp)\n",
    "            ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # LSTM initial states\n",
    "        else:\n",
    "            # If None add zeros\n",
    "            batch_size, example_length, target_size, hidden_size = t_observed_tgt.shape\n",
    "            cs = torch.zeros(size=(batch_size, hidden_size)).to(y_insample.device)\n",
    "            ce = torch.zeros(size=(batch_size, hidden_size)).to(y_insample.device)\n",
    "            ch = torch.zeros(size=(1, batch_size, hidden_size)).to(y_insample.device)\n",
    "            cc = torch.zeros(size=(1, batch_size, hidden_size)).to(y_insample.device)\n",
    "\n",
    "        # Historical inputs\n",
    "        _historical_inputs = [k_inp[:,:self.input_size,:],\n",
    "                              t_observed_tgt[:,:self.input_size,:]]\n",
    "        if o_inp is not None:\n",
    "            _historical_inputs.insert(0,o_inp[:,:self.input_size,:])\n",
    "        historical_inputs = torch.cat(_historical_inputs, dim=-2)\n",
    "\n",
    "        # Future inputs\n",
    "        future_inputs = k_inp[:, self.input_size:]\n",
    "\n",
    "        #---------------------------- Encode/Decode ---------------------------#\n",
    "        # Embeddings + VSN + LSTM encoders\n",
    "        temporal_features = self.temporal_encoder(historical_inputs=historical_inputs,\n",
    "                                                  future_inputs=future_inputs,\n",
    "                                                  cs=cs, ch=ch, cc=cc)\n",
    "\n",
    "        # Static enrichment, Attention and decoders\n",
    "        temporal_features = self.temporal_fusion_decoder(temporal_features=temporal_features,\n",
    "                                                         ce=ce)\n",
    "\n",
    "        # Adapt output to loss\n",
    "        y_hat = self.output_adapter(temporal_features)\n",
    "        y_hat = self.loss.domain_map(y_hat)\n",
    "\n",
    "        return y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TFT, title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TFT.fit, name='TFT.fit', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TFT.predict, name='TFT.predict', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss, GMM, PMM\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "#from neuralforecast.models import TFT\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss, GMM, PMM\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic\n",
    "\n",
    "#AirPassengersPanel['y'] = AirPassengersPanel['y'] + 10\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[TFT(h=12, input_size=48,\n",
    "                hidden_size=20,\n",
    "                #loss=DistributionLoss(distribution='Poisson', level=[80, 90]),\n",
    "                #loss=DistributionLoss(distribution='Normal', level=[80, 90]),\n",
    "                loss=DistributionLoss(distribution='StudentT', level=[80, 90]),\n",
    "                learning_rate=0.005,\n",
    "                stat_exog_list=['airline1'],\n",
    "                #futr_exog_list=['y_[lag12]'],\n",
    "                hist_exog_list=['trend'],\n",
    "                max_steps=500,\n",
    "                val_check_steps=10,\n",
    "                early_stop_patience_steps=10,\n",
    "                scaler_type='robust',\n",
    "                windows_batch_size=None,\n",
    "                enable_progress_bar=True),\n",
    "    ],\n",
    "    freq='M'\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
    "Y_hat_df = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Plot quantile predictions\n",
    "Y_hat_df = Y_hat_df.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "plt.plot(plot_df['ds'], plot_df['TFT'], c='purple', label='mean')\n",
    "plt.plot(plot_df['ds'], plot_df['TFT-median'], c='blue', label='median')\n",
    "plt.fill_between(x=plot_df['ds'][-12:], \n",
    "                 y1=plot_df['TFT-lo-90'][-12:].values, \n",
    "                 y2=plot_df['TFT-hi-90'][-12:].values,\n",
    "                 alpha=0.4, label='level 90')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
